% ***********************************************************************
% *                          Simulation 3                               *    
% ***********************************************************************

%% =============== Multi-Neuron Compound AP Simulation   ================
%
% Simulates compound action potentials (CAPs) from three unmyelinated
% neuron types with user-defined parameters. Calculates conduction 
% velocities, generates total CAP and individual neuron AP plots, and
% visualises how varying neuron properties affect signal propagation.

% Author: Oliver Foley
% Date: 29/04/2025

%% ====================== INITIALISE PARAMETERS ======================

% === INPUT PARAMETERS ===
tmax = 40;                        % Simulation time (ms)
dt = 0.05;                        % Time step (ms)
dx = 0.01;                        % Space step (cm)
axon_length = 1;                  % Axon length (cm)
t = 0:dt:tmax;                    % Time vector (ms)
% round() guards against floating-point quotients such as 0.3/0.01 = 29.999...,
% which would otherwise make zeros(Nx, 1) fail with a non-integer size.
Nx = round(axon_length / dx);     % Number of spatial points
Nt = length(t);                   % Number of time steps

% === STIMULUS SETTINGS ===
stim_start = 5;                   % Stimulus start time (ms)
stim_dur = 5;                     % Stimulus duration (ms)
stim_end = stim_start + stim_dur; % Stimulus end time (ms)
stim_amp = 1 / 1000;              % Stimulus amplitude (uA → mA)
stim_loc = 0.01;                  % Stimulus location on axon (cm)
% Clamp to the grid so a stimulus at (or very near) 0 cm maps to node 1
% instead of the invalid MATLAB index 0.
stim_index = min(Nx, max(1, round(Nx * stim_loc / axon_length))); % Stimulus index on grid

% === RECORDING ELECTRODE SETTINGS ===
recording_mode = 'monophasic';    % Options: 'monophasic' or 'biphasic'
pos1_percentage = 10;             % + electrode 1 (% of axon length)
neg1_percentage = pos1_percentage + 8; % - electrode 1
pos2_percentage = 90;             % + electrode 2
neg2_percentage = pos2_percentage + 8; % - electrode 2
pos_idx1 = round(Nx * pos1_percentage / 100); % Index + electrode 1
neg_idx1 = round(Nx * neg1_percentage / 100); % Index - electrode 1
pos_idx2 = round(Nx * pos2_percentage / 100); % Index + electrode 2
neg_idx2 = round(Nx * neg2_percentage / 100); % Index - electrode 2

% Fail early with a readable message if an electrode that will actually be
% read falls off the grid (the '-' electrodes are only read in biphasic mode).
if strcmp(recording_mode, 'biphasic')
    used_electrode_idx = [pos_idx1 neg_idx1 pos_idx2 neg_idx2];
else
    used_electrode_idx = [pos_idx1 pos_idx2];
end
if any(used_electrode_idx < 1 | used_electrode_idx > Nx)
    error('Recording electrodes must lie within grid indices 1..%d; got %s.', ...
          Nx, mat2str(used_electrode_idx));
end

% === NEURON TYPES SETUP ===
neuron_types(1) = struct('name', 'Type 1', 'count', 50, 'Cm', 1, 'axon_radius', 0.0023/10, 'Ra', 0.15);
neuron_types(2) = struct('name', 'Type 2', 'count', 20, 'Cm', 1, 'axon_radius', 0.0020/10, 'Ra', 0.16);
neuron_types(3) = struct('name', 'Type 3', 'count', 20, 'Cm', 1, 'axon_radius', 0.0017/10, 'Ra', 0.17);

% === INITIALISE STORAGE ===
CAP1 = zeros(Nt, 1);                      % Compound AP at electrode pair 1
CAP2 = zeros(Nt, 1);                      % Compound AP at electrode pair 2
CAP_per_type1 = zeros(Nt, length(neuron_types)); % APs per neuron type (recording 1)
CAP_per_type2 = zeros(Nt, length(neuron_types)); % APs per neuron type (recording 2)
conduction_velocities = NaN(1, length(neuron_types)); % Velocity placeholders

%% ======================== RUN SIMULATION =========================

% Simulate each neuron type once, then weight its trace by the number of
% fibres of that type to build the compound action potential (CAP).
for i = 1:length(neuron_types)
    nt = neuron_types(i);   % 'nt' rather than 'type', which shadows a MATLAB built-in

    % Run single axon model for current neuron type
    [~, signal1, signal2] = runSingleAxon(t, Nx, nt.Cm, nt.axon_radius, nt.Ra, stim_index, stim_start, stim_end, stim_amp, recording_mode, pos1_percentage, neg1_percentage, pos2_percentage, neg2_percentage);

    % All fibres of one type share the same trace, so the compound signal
    % is the count-weighted sum of the per-type traces.
    CAP1 = CAP1 + nt.count * signal1;
    CAP2 = CAP2 + nt.count * signal2;

    % Store individual traces for plotting
    CAP_per_type1(:,i) = signal1;
    CAP_per_type2(:,i) = signal2;

    % Conduction velocity = electrode separation / delay between the first
    % prominent peak (>= 20% of the trace maximum) at each electrode.
    [~, locs1] = findpeaks(signal1, 'NPeaks', 1, 'MinPeakHeight', max(signal1) * 0.2);
    [~, locs2] = findpeaks(signal2, 'NPeaks', 1, 'MinPeakHeight', max(signal2) * 0.2);
    if ~isempty(locs1) && ~isempty(locs2)
        delta_t = (t(locs2(1)) - t(locs1(1))) * 1e-3;                             % Time delay (ms -> s)
        separation_cm = (pos2_percentage - pos1_percentage) / 100 * axon_length;  % Electrode separation (cm)
        separation_m = separation_cm / 100;                                       % Electrode separation (m)
        % A non-positive delay means the peaks were not a forward-propagating
        % AP pair; leave the velocity as NaN rather than reporting Inf/negative.
        if delta_t > 0
            conduction_velocities(i) = separation_m / delta_t;                     % Velocity (m/s)
        end
    end
end

%% =========================== PLOTTING ============================

% One row for the compound AP plus one row per neuron type, so the layout
% follows however many types are configured in neuron_types.
n_types = numel(neuron_types);
n_rows = n_types + 1;

% Plot combined CAP at both recording locations
figure('Color', 'w', 'Position', [100 100 700 800]);
subplot(n_rows,1,1);
plot(t, CAP1, 'r', t, CAP2, 'b', 'LineWidth', 1.5);
title('Nerve Compound Action Potential Recording');
ylabel('Summed Voltage (mV)');
xlabel('Time (ms)');

% Electrode labels are derived from the configured positions instead of
% being hard-coded. axon_length is in cm, so pct/100 * length * 10 mm/cm
% simplifies to pct * length / 10 millimetres.
legend({sprintf('Electrode 1: %g mm', pos1_percentage * axon_length / 10), ...
        sprintf('Electrode 2: %g mm', pos2_percentage * axon_length / 10)}, ...
       'Location', 'best', 'FontSize', 10);
grid on;

% Plot action potentials per neuron type with velocity estimates
for i = 1:n_types
    subplot(n_rows,1,i+1);
    plot(t, CAP_per_type1(:,i), 'r', t, CAP_per_type2(:,i), 'b--', 'LineWidth', 1.5);
    if ~isnan(conduction_velocities(i))
        title(sprintf('%s Neuron Action Potential Recording (Conduction Velocity: %.2f m/s)', neuron_types(i).name, conduction_velocities(i)));
    else
        title(sprintf('%s Neuron Action Potential Recording (Conduction Velocity: N/A)', neuron_types(i).name));
    end
    ylabel('Membrane Voltage (mV)');
    grid on;
    xlabel('Time (ms)');
end

%% =================== HELPER FUNCTION: Single Axon ====================
function [t, Signal1, Signal2] = runSingleAxon(t, Nx, Cm, axon_radius, Ra, stim_index, stim_start, stim_end, stim_amp, mode, pos1_pct, neg1_pct, pos2_pct, neg2_pct, dx)
% RUNSINGLEAXON Simulate one axon as a 1-D cable with Hodgkin-Huxley-style
% Na/K/leak (Cl) currents and record the voltage at two electrode sites.
%
% Inputs
%   t            - time vector (ms); the step is taken as t(2) - t(1)
%   Nx           - number of spatial grid points
%   Cm           - membrane capacitance
%   axon_radius  - axon radius (presumably cm, to match dx — confirm)
%   Ra           - axial resistivity
%   stim_index   - grid node that receives the stimulus (1..Nx)
%   stim_start, stim_end - stimulus is on for stim_start < t < stim_end
%   stim_amp     - stimulus amplitude; converted to a density by dividing
%                  by the lateral area 2*pi*axon_radius*dx of one segment
%   mode         - 'biphasic' records V(+) - V(-); any other value records V(+)
%   pos*_pct, neg*_pct - electrode positions as % of the Nx-point grid
%   dx           - (optional) spatial step in cm. Defaults to 0.01 for
%                  backward compatibility; pass the caller's dx so the
%                  coupling matrix matches the grid used to choose Nx.
%
% Outputs
%   t                - the input time vector, returned unchanged
%   Signal1, Signal2 - length(t)-by-1 recordings at electrode pairs 1 and 2
    if nargin < 15 || isempty(dx)
        dx = 0.01;
    end
    dt = t(2) - t(1);

    % === MODEL CONSTANTS ===
    gNa_max = 120; gK_max = 36; gCl = 0.0667;
    ENa = 56; EK = -77; ECl = -68;

    % Convert percentage locations to spatial indices
    pos1 = round(Nx * pos1_pct / 100);
    neg1 = round(Nx * neg1_pct / 100);
    pos2 = round(Nx * pos2_pct / 100);
    neg2 = round(Nx * neg2_pct / 100);

    % The recording mode does not change during the run, so decide it once.
    % The '-' electrodes are only read in biphasic mode, so only validate
    % the indices that will actually be used.
    is_biphasic = strcmp(mode, 'biphasic');
    if is_biphasic
        rec_idx = [pos1 neg1 pos2 neg2];
    else
        rec_idx = [pos1 pos2];
    end
    if any(rec_idx < 1 | rec_idx > Nx)
        error('runSingleAxon:electrodeOutOfRange', ...
              'Electrode indices %s lie outside the grid 1..%d.', mat2str(rec_idx), Nx);
    end
    if stim_index < 1 || stim_index > Nx
        error('runSingleAxon:stimOutOfRange', ...
              'Stimulus index %g lies outside the grid 1..%d.', stim_index, Nx);
    end

    % === INITIALISE ===
    V = ECl * ones(Nx, 1);            % Start every node at the Cl reversal potential
    I = zeros(Nx, 1);                 % External current vector
    n = an(V) ./ (an(V) + bn(V));     % Gates start at their steady state for V
    m = am(V) ./ (am(V) + bm(V));
    h = ah(V) ./ (ah(V) + bh(V));

    % Second-difference coupling matrix; the corner entries are 1/dx^2
    % instead of 2/dx^2, i.e. sealed (zero-flux) axon ends.
    e = ones(Nx,1);
    B = spdiags([-e 2*e -e], -1:1, Nx, Nx) / dx^2;
    B(1,1) = 1/dx^2; B(end,end) = 1/dx^2;
    B = (axon_radius / (2 * Ra)) * B;
    dB = diag(B);                     % Coupling-only diagonal; reused every step

    % Stimulus current density is constant while the stimulus is on.
    stim_density = stim_amp / (2 * pi * axon_radius * dx);

    % Initialise output signals
    Signal1 = zeros(length(t), 1);
    Signal2 = zeros(length(t), 1);

    % === SOLVE MODEL OVER TIME ===
    for j = 1:length(t)
        t_now = t(j);

        % Stimulus is applied only at stim_index and only inside the window
        I(stim_index) = (t_now > stim_start && t_now < stim_end) * stim_density;

        % Gating variables: trapezoidal update of dx/dt = a - (a + b) x,
        % solved in closed form for the new value.
        a_n = an(V); c_n = (a_n + bn(V)) / 2;
        n = ((1/dt - c_n) .* n + a_n) ./ (1/dt + c_n);
        a_m = am(V); c_m = (a_m + bm(V)) / 2;
        m = ((1/dt - c_m) .* m + a_m) ./ (1/dt + c_m);
        a_h = ah(V); c_h = (a_h + bh(V)) / 2;
        h = ((1/dt - c_h) .* h + a_h) ./ (1/dt + c_h);

        % Voltage: backward-Euler solve to t + dt/2, then extrapolate
        % V_new = 2*V_mid - V, which together are the Crank-Nicolson step.
        m3h = m.^3 .* h; n4 = n.^4;
        d = gNa_max * m3h + gK_max * n4 + gCl;
        f = gNa_max * m3h * ENa + gK_max * n4 * EK + gCl * ECl + I;
        B(1:Nx+1:end) = dB + d + 2*Cm/dt;   % overwrite the diagonal in place
        Vmid = B \ (2*Cm*V/dt + f);
        V = 2 * Vmid - V;

        % Extract signals at specified electrodes
        if is_biphasic
            Signal1(j) = V(pos1) - V(neg1);
            Signal2(j) = V(pos2) - V(neg2);
        else
            Signal1(j) = V(pos1);
            Signal2(j) = V(pos2);
        end
    end
end
 

%% ========================= GATING FUNCTIONS =========================
% Hodgkin-Huxley rate functions, written in terms of (V + 71), i.e. the
% membrane voltage measured from -71 (mV, per the plot labels). All accept
% scalar or array V and operate element-wise.
%
% alpha_n and alpha_m have the form k*w / (exp(w/10) - 1), which evaluates
% to 0/0 = NaN when w = 0 (V = -61 for alpha_n, V = -46 for alpha_m). The
% analytic limit there is 10*k, so those points are patched explicitly;
% a NaN rate would otherwise spread through the voltage solve. Every other
% value is computed with exactly the original expression.
function val = an(V)
    x = 1 - (V + 71) / 10;
    val = 0.01 * (10 - (V + 71)) ./ (exp(x) - 1);  % alpha_n
    val(x == 0) = 0.1;                              % limit of 0.01*w/(exp(w/10)-1) as w -> 0
end

function val = bn(V)
    val = 0.125 * exp(-(V + 71) / 80);              % beta_n
end

function val = am(V)
    x = 2.5 - (V + 71) / 10;
    val = 0.1 * (25 - (V + 71)) ./ (exp(x) - 1);    % alpha_m
    val(x == 0) = 1;                                % limit of 0.1*w/(exp(w/10)-1) as w -> 0
end

function val = bm(V)
    val = 4 * exp(-(V + 71) / 18);                  % beta_m
end

function val = ah(V)
    val = 0.07 * exp(-(V + 71) / 20);               % alpha_h
end

function val = bh(V)
    val = 1 ./ (exp(3 - (V + 71) / 10) + 1);        % beta_h
end
